This machine learning exercise follows a project by AlverniaEP; see her project for the original version.
The project concerns a ride-sharing service in Turkey. The objective is to forecast hourly driver demand for the coming days; the code below holds out and forecasts the last 6 days (24 × 6 = 144 hours). Three forecasting methods are compared:

- ets
- stlm
- tbats

For background, read *Forecasting: Principles and Practice* by Rob J Hyndman (author of the forecast package) and George Athanasopoulos.
Load the packages and the dataset. The data has the following features:

- timeStamp: order time
- driverID: driver ID
- riderID: rider ID
- orderStatus: outcome of the order ("nodrivers", "cancelled" or "confirmed" in this data)
- confirmedTimeSec: not documented (0 for the unconfirmed orders shown below)
- srcGeohash: geographic area of the source location, presumably a geohash cell
- srcLong: source location longitude
- srcLat: source location latitude
- destLong: destination location longitude
- destLat: destination location latitude
library(tidyverse)
## -- Attaching packages -------------------------------------------------------------------------------------------- tidyverse 1.3.0 --
## v ggplot2 3.3.2 v purrr 0.3.4
## v tibble 3.0.2 v dplyr 1.0.0
## v tidyr 1.1.0 v stringr 1.4.0
## v readr 1.3.1 v forcats 0.5.0
## -- Conflicts ----------------------------------------------------------------------------------------------- tidyverse_conflicts() --
## x dplyr::filter() masks stats::filter()
## x dplyr::lag() masks stats::lag()
library(lubridate)
##
## Attaching package: 'lubridate'
## The following objects are masked from 'package:base':
##
## date, intersect, setdiff, union
library(magrittr)
##
## Attaching package: 'magrittr'
## The following object is masked from 'package:purrr':
##
## set_names
## The following object is masked from 'package:tidyr':
##
## extract
library(forecast)
## Registered S3 method overwritten by 'quantmod':
## method from
## as.zoo.data.frame zoo
library(ggthemes)
library(plotly)
##
## Attaching package: 'plotly'
## The following object is masked from 'package:ggplot2':
##
## last_plot
## The following object is masked from 'package:stats':
##
## filter
## The following object is masked from 'package:graphics':
##
## layout
# Pick the ride-sharing CSV interactively. file.choose() works on every
# platform, whereas choose.files() only exists in Windows builds of R.
# NOTE(review): for a reproducible render, replace this with a fixed path.
rideSharing <- read_csv(file.choose())
## Parsed with column specification:
## cols(
## timeStamp = col_datetime(format = ""),
## driverID = col_character(),
## riderID = col_character(),
## orderStatus = col_character(),
## confirmedTimeSec = col_double(),
## srcGeohash = col_character(),
## srcLong = col_double(),
## srcLat = col_double(),
## destLong = col_double(),
## destLat = col_double()
## )
# Preview the first six rows of the raw data.
rideSharing %>% head(n = 6L)
## Warning: `...` is not empty.
##
## We detected these problematic arguments:
## * `needs_dots`
##
## These dots only exist to allow future extensions and should be empty.
## Did you misspecify an argument?
## # A tibble: 6 x 10
## timeStamp driverID riderID orderStatus confirmedTimeSec srcGeohash
## <dttm> <chr> <chr> <chr> <dbl> <chr>
## 1 2017-11-03 19:02:31 <NA> 594108~ nodrivers 0 6
## 2 2017-10-01 17:45:56 <NA> 59d0d4~ nodrivers 0 7
## 3 2017-10-01 17:46:01 <NA> 59d0d4~ nodrivers 0 7
## 4 2017-10-01 19:07:09 <NA> 59319c~ nodrivers 0 7
## 5 2017-10-01 19:45:52 <NA> 59898b~ nodrivers 0 7
## 6 2017-10-02 01:39:52 <NA> 59c2cd~ nodrivers 0 7
## # ... with 4 more variables: srcLong <dbl>, srcLat <dbl>, destLong <dbl>,
## # destLat <dbl>
# srcGeohash and orderStatus
# The grouped summary below suggests that most orders come from a single
# srcGeohash area; orders from the other areas mostly ended as "nodrivers"
# or "cancelled".
# count() returns an ungrouped tibble, so no regrouping message is printed.
rideSharing %>%
  count(srcGeohash, orderStatus, name = "total_status")
## `summarise()` regrouping output by 'srcGeohash' (override with `.groups` argument)
## Warning: `...` is not empty.
##
## We detected these problematic arguments:
## * `needs_dots`
##
## These dots only exist to allow future extensions and should be empty.
## Did you misspecify an argument?
## # A tibble: 21 x 3
## # Groups: srcGeohash [13]
## srcGeohash orderStatus total_status
## <chr> <chr> <int>
## 1 6 nodrivers 1
## 2 7 cancelled 26
## 3 7 confirmed 1
## 4 7 nodrivers 512
## 5 9 cancelled 1
## 6 9 nodrivers 11
## 7 c nodrivers 2
## 8 d nodrivers 8
## 9 e cancelled 2
## 10 e nodrivers 142
## # ... with 11 more rows
# orderStatus
# As noted by the original author, orderStatus contains "cancelled" orders
# and duplicated "nodrivers" orders. The cleaning steps below assume:
#   1. cancelled orders are not demand and are dropped;
#   2. repeated "nodrivers" orders by the same rider within the same hour are
#      retries of one request, so only the first of such a run is kept;
#   3. a "nodrivers" order directly followed by a "confirmed" order in the
#      same hour is a retry that eventually succeeded, so only the
#      confirmation is kept.
orderData <- rideSharing %>%
  filter(orderStatus != "cancelled") %>%
  mutate(Date = timeStamp %>% format("%Y-%m-%d"),
         Hour = timeStamp %>% hour()) %>%
  # Sort chronologically within each rider so lag()/lead() below compare an
  # order with the one actually placed just before/after it. Sorting only by
  # Date and Hour left orders inside the same hour in raw file order, and the
  # file is not in time order.
  arrange(riderID, timeStamp) %>%
  group_by(riderID, Date, Hour) %>%
  # Rule 2: drop a "nodrivers" whose previous order in the group was also
  # "nodrivers". %in% treats the missing lag of the first row as FALSE.
  mutate(prevOrder = lag(orderStatus),
         orderStatus = if_else(orderStatus == "nodrivers" &
                                 prevOrder %in% "nodrivers",
                               NA_character_, orderStatus)) %>%
  filter(!is.na(orderStatus)) %>%
  # Rule 3: only "nodrivers" and "confirmed" remain; drop a "nodrivers" that
  # is immediately followed by a "confirmed" so the request counts once.
  mutate(nextOrder = lead(orderStatus),
         orderStatus = if_else(orderStatus == "nodrivers" &
                                 nextOrder %in% "confirmed",
                               NA_character_, orderStatus)) %>%
  filter(!is.na(orderStatus)) %>%
  ungroup() %>%
  # Keep only the columns needed for demand forecasting.
  select(timeStamp, Date, Hour, riderID, orderStatus)
head(orderData)
## Warning: `...` is not empty.
##
## We detected these problematic arguments:
## * `needs_dots`
##
## These dots only exist to allow future extensions and should be empty.
## Did you misspecify an argument?
## # A tibble: 6 x 5
## timeStamp Date Hour riderID orderStatus
## <dttm> <chr> <int> <chr> <chr>
## 1 2017-11-18 13:18:06 2017-11-18 13 58c420322eb67e3311540073 confirmed
## 2 2017-11-25 21:48:23 2017-11-25 21 58c52e464a1df942e6052816 confirmed
## 3 2017-10-01 14:19:09 2017-10-01 14 58c6e450277616158a33c1fe confirmed
## 4 2017-10-08 14:51:13 2017-10-08 14 58c6e450277616158a33c1fe nodrivers
## 5 2017-10-08 18:28:59 2017-10-08 18 58c6e450277616158a33c1fe confirmed
## 6 2017-10-11 13:16:24 2017-10-11 13 58c6e450277616158a33c1fe confirmed
# Hourly demand: number of cleaned orders per (Date, Hour), laid out on a
# complete hourly grid so it can be used as a regular ts with frequency 24.
demandData <- orderData %>%
  mutate(Date = as.Date(Date)) %>%
  count(Date, Hour, name = "Demand") %>%
  # Fill every hour 0-23 of every calendar day between the first and last
  # order date with zero demand. complete(Date, Hour) alone only crosses the
  # dates and hours that already occur, so a day or hour with no orders at
  # all would vanish and shift the 24-hour seasonality.
  complete(Date = seq(min(Date), max(Date), by = "day"),
           Hour = 0:23,
           fill = list(Demand = 0L)) %>%
  mutate(Demand = as.numeric(Demand),
         # Date/Hour were derived from timeStamp in UTC (readr's default
         # zone), so rebuild the hourly timestamp in UTC too; local time
         # zones with DST could otherwise produce NA or repeated hours.
         timeStamp = as.POSIXct(paste(Date, Hour), format = "%Y-%m-%d %H",
                                tz = "UTC"))
## `summarise()` regrouping output by 'Date' (override with `.groups` argument)
# Preview the first six hours of the demand series.
demandData %>% head(n = 6L)
## Warning: `...` is not empty.
##
## We detected these problematic arguments:
## * `needs_dots`
##
## These dots only exist to allow future extensions and should be empty.
## Did you misspecify an argument?
## # A tibble: 6 x 4
## Date Hour Demand timeStamp
## <date> <int> <dbl> <dttm>
## 1 2017-10-01 0 0 2017-10-01 00:00:00
## 2 2017-10-01 1 0 2017-10-01 01:00:00
## 3 2017-10-01 2 0 2017-10-01 02:00:00
## 4 2017-10-01 3 67 2017-10-01 03:00:00
## 5 2017-10-01 4 31 2017-10-01 04:00:00
## 6 2017-10-01 5 18 2017-10-01 05:00:00
# Single-seasonal hourly demand series with a 24-hour cycle.
# NOTE(review): `start` is a Date, which ts() presumably reduces to its
# underlying number (days since 1970-01-01), so the time axis is not in
# calendar dates - confirm if the axis labels matter.
demandData_ts <- ts(data = demandData[["Demand"]],
                    start = min(demandData[["Date"]]),
                    frequency = 24)

# Classical decomposition of the first four weeks (24 * 7 * 4 hours);
# the autoplot() method comes from the forecast package.
autoplot(decompose(head(demandData_ts, 24 * 7 * 4))) +
  theme_bw()
# Multi-seasonal version of the series: daily (24 h) and weekly (168 h) cycles.
demandData_tsSsn <- msts(demandData[["Demand"]],
                         seasonal.periods = c(24, 24 * 7))

# Multiple-seasonal STL decomposition of the first four weeks.
autoplot(mstl(head(demandData_tsSsn, 24 * 7 * 4))) +
  theme_bw()
# Total demand per hour of the day, one bar per weekday.
# Demand is summed over all dates that share a weekday first. Plotting the
# raw rows with position_dodge() stacked every date of a weekday at the same
# dodge position, so each bar only showed its largest single day, not the
# total the title promises.
# wday(label = TRUE, week_start = 1) yields a Monday-to-Sunday ordered
# factor, so the legend follows the calendar instead of the alphabet.
demandData %>%
  mutate(Day = wday(Date, label = TRUE, abbr = FALSE, week_start = 1)) %>%
  group_by(Day, Hour) %>%
  summarise(Demand = sum(Demand), .groups = "drop") %>%
  ggplot(aes(x = Hour,
             y = Demand,
             fill = Day)) +
  geom_col(position = position_dodge()) +
  theme_bw() +
  labs(fill = "",
       title = "Total Demand at Hour of the Day")
# Train/test split: the last 24 * 6 hours (6 days) are held out as y_test.
# The training series are shifted up by `constant` so the log transform
# (lambda = 0) used by the models never meets a zero-demand hour.
constant <- 1
inputForecast <- (demandData_ts + constant) %>%
  head(length(demandData_ts) - 24 * 6)
inputForecast_Sn <- (demandData_tsSsn + constant) %>%
  head(length(demandData_tsSsn) - 24 * 6)
y_test <- demandData_ts %>%
  tail(24 * 6)
# Undo the constant shift added before fitting: the series was modelled as
# y + constant because log(0) is -Inf. Subtracts `constant` from the chosen
# components of a forecast object - by default the point forecast and the
# interval bounds.
#
# forecastResult: a forecast object (or any list) holding the components.
# constant:       the shift to remove.
# components:     names of the components to shift back. Components that are
#                 absent (NULL) are skipped instead of being created as empty
#                 numeric(0) elements. Add "x" and "fitted" to also shift the
#                 training data and fitted values, e.g. before autoplot().
# Returns forecastResult with the selected components shifted down.
removeConstant <- function(forecastResult, constant = 1,
                           components = c("mean", "upper", "lower")) {
  for (component in components) {
    if (!is.null(forecastResult[[component]])) {
      forecastResult[[component]] <- forecastResult[[component]] - constant
    }
  }
  forecastResult
}
# ets on the single-seasonal series, fitted on the log scale (lambda = 0);
# forecast the 6-day test window and undo the constant shift.
tsForecast <- removeConstant(forecast(ets(inputForecast, lambda = 0),
                                      h = 24 * 6))

# stlm on the multi-seasonal msts, also on the log scale.
mstsForecast <- removeConstant(forecast(stlm(inputForecast_Sn, lambda = 0),
                                        h = 24 * 6))

# tbats cannot be given a fixed lambda, so fit it on log(y) without Box-Cox
# and then set lambda = 0 on the model so forecast() back-transforms with exp().
tbatsForecast <- tbats(log(inputForecast_Sn), use.box.cox = FALSE)
tbatsForecast$lambda <- 0
tbatsForecast <- removeConstant(forecast(tbatsForecast, h = 24 * 6))
## Warning in InvBoxCox(y.forecast, object$lambda, biasadj, list(level = level, :
## biasadj information not found, defaulting to FALSE.
# Test-set accuracy of each model's rounded point forecasts, one row per model.
# The MPE/MAPE columns come out as -Inf/Inf because y_test contains hours
# with zero actual demand; compare the models on ME, RMSE and MAE instead.
list(ets = tsForecast, stlm = mstsForecast, tbats = tbatsForecast) %>%
  map_dfr(function(fc) {
    as_tibble(accuracy(round(as.vector(fc$mean)), y_test))
  }, .id = "Model")
## Warning: `...` is not empty.
##
## We detected these problematic arguments:
## * `needs_dots`
##
## These dots only exist to allow future extensions and should be empty.
## Did you misspecify an argument?
## # A tibble: 3 x 8
## Model ME RMSE MAE MPE MAPE ACF1 `Theil's U`
## <chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
## 1 ets -10.3 82.5 52.2 -Inf Inf 0.764 0
## 2 stlm -23.1 69.2 44.4 -Inf Inf 0.596 0
## 3 tbats -27.9 67.8 48.6 -Inf Inf 0.609 0
# Test-window table of actual demand next to each model's point forecast.
# as.vector() strips the ts attributes from y_test so every value column is a
# plain numeric vector; otherwise reshaping these columns later warns that
# "attributes are not identical across measure variables".
# NOTE: "stmlModel" is a typo for stlm, kept because later code uses the name.
forecastCompare <- data.frame(Timestamp = tail(demandData$timeStamp, 24 * 6),
                              Actual = as.vector(y_test),
                              tsModel = as.vector(tsForecast$mean),
                              stmlModel = as.vector(mstsForecast$mean),
                              tbatsModel = as.vector(tbatsForecast$mean))
# Interactive line chart of actual demand versus each model's forecast over
# the test window.
ggplotly(
  forecastCompare %>%
    # Drop any ts attributes from Actual so all measure columns share one
    # plain numeric type (avoids the attribute-mismatch warning).
    mutate(Actual = as.vector(Actual)) %>%
    pivot_longer(c(Actual, tsModel, stmlModel, tbatsModel),
                 names_to = "key", values_to = "value") %>%
    ggplot(aes(x = Timestamp,
               y = value,
               color = key)) +
    geom_line() +
    theme_bw()
)
## Warning: attributes are not identical across measure variables;
## they will be dropped
# Forecast errors (actual minus forecast) of each model over the test window.
# The plot is built once and shown both as a static ggplot and as an
# interactive plotly chart; the pipeline used to be written out twice.
errorPlot <- forecastCompare %>%
  mutate(Actual = as.vector(Actual),  # drop any ts attributes
         tsError = Actual - tsModel,
         stmlError = Actual - stmlModel,
         tbatsError = Actual - tbatsModel) %>%
  select(Timestamp, ends_with("Error")) %>%
  pivot_longer(-Timestamp, names_to = "key", values_to = "value") %>%
  ggplot(aes(x = Timestamp,
             y = value,
             color = key,
             shape = key)) +
  geom_point(alpha = 0.4, size = 3) +
  theme_bw()

errorPlot
ggplotly(errorPlot)
# Fit a model with `fun`, forecast `h` steps ahead and shift the point
# forecast and interval bounds back down by `addDependentVar`.
# The input must ALREADY be shifted up by `addDependentVar` (so a log
# transform, lambda = 0, never sees a zero); this function only removes the
# shift from the forecasts, it does not add it to the input.
#
# input:           training series, already shifted by `addDependentVar`.
# fun:             model function taking the series first and a `lambda`
#                  argument, e.g. forecast::ets or forecast::stlm.
# lambda:          Box-Cox parameter passed to `fun` (0 = log transform).
# h:               forecast horizon in time steps.
# addDependentVar: the shift to subtract from the forecasts.
# ...:             further arguments passed on to `fun`.
# Returns the forecast object with mean/upper/lower on the unshifted scale.
fets <- function(input, fun, lambda = 0, h = 10, addDependentVar = 1, ...) {
  # Subtract the shift from the point forecast and interval bounds, skipping
  # any component the forecast object does not contain.
  removeConstant <- function(forecastResult, constant = addDependentVar) {
    for (component in c("mean", "upper", "lower")) {
      if (!is.null(forecastResult[[component]])) {
        forecastResult[[component]] <- forecastResult[[component]] - constant
      }
    }
    forecastResult
  }
  model <- fun(input, lambda = lambda, ...)
  removeConstant(forecast(model, h = h))
}